
Save & Load Model Weights

Summary

  • Save weights
  • Load weights

Content

Save Weights

  • A checkpoint callback can be created as follows:
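import tensorflow as tf

# With tf.keras 2.x this path is saved in TensorFlow checkpoint format;
# Keras 3 (TF 2.16+) requires a path ending in ".weights.h5" instead.
checkpoint_path = "models/cnn.ckpt"

model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor="val_loss",      # only consulted when save_best_only=True
    save_weights_only=True,  # save the weights, not the full model
    save_best_only=False,    # save after every epoch, regardless of val_loss
    save_freq="epoch",
    verbose=1,               # print a message whenever a checkpoint is saved
)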
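The `monitor` argument only matters when `save_best_only=True`; with `save_best_only=False`, as above, a checkpoint is written at the end of every epoch. The function below is a simplified pure-Python model of that decision for a loss-like metric (lower is better), not the callback's actual source; the function name and the validation losses are hypothetical.

```python
def epochs_written(val_losses, save_best_only):
    """Return the 1-based epochs at which a checkpoint would be written.

    Simplified model of ModelCheckpoint with save_freq="epoch" and a
    monitored metric where lower is better (e.g. "val_loss").
    """
    if not save_best_only:
        # Every epoch writes (and, with a fixed path, overwrites) the file.
        return list(range(1, len(val_losses) + 1))
    written = []
    best = float("inf")
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:  # only a strict improvement triggers a save
            best = loss
            written.append(epoch)
    return written


# Hypothetical validation losses for a 4-epoch run.
losses = [2.29, 2.21, 2.24, 2.18]
print(epochs_written(losses, save_best_only=False))  # [1, 2, 3, 4]
print(epochs_written(losses, save_best_only=True))   # [1, 2, 4]
```

With `save_best_only=True`, epoch 3 is skipped because its loss did not improve on epoch 2.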
  • Register the callback with fit when training the model:
# steps_per_epoch=1 runs a single training batch per epoch
model.fit(
    train_data,
    epochs=4,
    steps_per_epoch=1,
    validation_data=test_data,
    callbacks=[model_checkpoint],
)
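For a self-contained run of the save step, the sketch below substitutes a tiny dense classifier and random 10-class data for the article's `model`, `train_data` and `test_data` (all hypothetical stand-ins). It uses a `.weights.h5` file name because Keras 3 (the `tf.keras` shipped with TensorFlow 2.16 and later) requires that suffix when `save_weights_only=True`; tf.keras 2.x also accepts it and writes HDF5. The file goes into a temporary directory.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for the article's data: random features, 10 classes.
rng = np.random.default_rng(0)
x_train = rng.normal(size=(64, 8)).astype("float32")
y_train = rng.integers(0, 10, size=(64,))
x_test = rng.normal(size=(32, 8)).astype("float32")
y_test = rng.integers(0, 10, size=(32,))

# Hypothetical stand-in for the article's CNN: a tiny dense classifier.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=["accuracy"],
)

# Keras 3 insists on the ".weights.h5" suffix when save_weights_only=True.
checkpoint_path = os.path.join(tempfile.mkdtemp(), "toy.weights.h5")
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor="val_loss",
    save_weights_only=True,
    save_best_only=False,
    save_freq="epoch",
    verbose=1,
)

# The callback runs at the end of every epoch and overwrites the same file.
history = model.fit(
    x_train,
    y_train,
    epochs=4,
    validation_data=(x_test, y_test),
    callbacks=[model_checkpoint],
    verbose=0,
)
print(os.path.exists(checkpoint_path))  # True
```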
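  • After training, the saved checkpoint files look like the figure below. Since the path contains no per-epoch placeholder, every epoch overwrites the same checkpoint, which therefore holds the weights from the last epoch. With tf.keras 2.x, a weights path such as "models/cnn.ckpt" is stored in TensorFlow checkpoint format as models/checkpoint, models/cnn.ckpt.index and models/cnn.ckpt.data-00000-of-00001.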

[Figure: saved checkpoints]
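To keep one checkpoint per epoch instead of overwriting a single file, the Keras documentation describes `filepath` as a template whose named fields are filled from the epoch number and the epoch's logs (such as `val_loss`). The snippet below only imitates that substitution with `str.format`, assuming 1-based epoch numbers as the callback reports them; the template and the logged values are hypothetical.

```python
# Hypothetical per-epoch template; {epoch} and {val_loss} are filled per epoch.
template = "models/cnn-{epoch:02d}-{val_loss:.2f}.weights.h5"

# Hypothetical logs for two epochs (epoch numbers are 1-based).
logs_per_epoch = [{"val_loss": 2.2841}, {"val_loss": 2.1837}]

paths = [
    template.format(epoch=epoch, **logs)
    for epoch, logs in enumerate(logs_per_epoch, start=1)
]
print(paths)
# ['models/cnn-01-2.28.weights.h5', 'models/cnn-02-2.18.weights.h5']
```

Passing such a template as the callback's first argument makes each epoch write a distinct file, at the cost of more disk space.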

Load Weights

  • Create a new model with the same architecture and compile it:
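# clone_model rebuilds the architecture with newly initialized weights;
# it does not copy the trained weights of `model`
model_clone = tf.keras.models.clone_model(model)
model_clone.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=["accuracy"],
)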
  • Load the saved weights into the new model and compare the evaluation results:
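model_clone.evaluate(test_data)
# loss: 2.3026 - accuracy: 0.1036
# Before loading, the loss is close to ln(10) = 2.3026 and accuracy is near 10%,
# i.e. roughly uniform guessing (assuming 10 classes).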

model_clone.load_weights(checkpoint_path)
model_clone.evaluate(test_data)
# loss: 2.1837 - accuracy: 0.2152

model.evaluate(test_data)
# loss: 2.1837 - accuracy: 0.2152

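# After the weights have been loaded into the cloned model, both models produce
# the same loss and accuracy. Since the checkpoint was overwritten every epoch,
# it holds the weights from the final epoch, i.e. the current state of `model`.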
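Matching `evaluate()` numbers are strong evidence, but the weights can also be compared directly. `get_weights()` returns a list of NumPy arrays, so a small helper (hypothetical, not part of Keras) can check them array by array, e.g. `same_weights(model.get_weights(), model_clone.get_weights())`. The arrays in the usage lines are made up so the sketch runs without a trained model.

```python
import numpy as np


def same_weights(weights_a, weights_b, atol=0.0):
    """Return True if two lists of weight arrays match in shape and value.

    Intended for the output of Keras `Model.get_weights()`, which is a list
    of NumPy arrays in layer order. atol=0.0 demands exact equality.
    """
    if len(weights_a) != len(weights_b):
        return False
    return all(
        a.shape == b.shape and np.allclose(a, b, rtol=0.0, atol=atol)
        for a, b in zip(weights_a, weights_b)
    )


# Made-up arrays standing in for model.get_weights() / model_clone.get_weights().
kernel = np.array([[0.1, -0.2], [0.3, 0.4]], dtype=np.float32)
bias = np.array([0.0, 0.5], dtype=np.float32)

print(same_weights([kernel, bias], [kernel.copy(), bias.copy()]))  # True
print(same_weights([kernel, bias], [kernel + 1e-3, bias]))         # False
```

The shape check runs first so that arrays which merely broadcast against each other are not reported as equal.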